"""
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
"""

import torch
from torch.nn import Module, Dropout


def elu_feature_map(x):
    # elu(x) + 1 is strictly positive, which keeps the linear-attention kernel non-negative
    return torch.nn.functional.elu(x) + 1


class LinearAttention(Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.feature_map = elu_feature_map
        self.eps = eps

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-Head linear attention proposed in "Transformers are RNNs"
        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L]
            kv_mask: [N, S]
        Returns:
            queried_values: [N, L, H, D]
        """
        Q = self.feature_map(queries)
        K = self.feature_map(keys)
        # print("q_mask:", q_mask)
        # print("kv_mask:", kv_mask)
        # zero out padded positions so they do not contribute to the attention sums
        if q_mask is not None:
            Q = Q * q_mask[:, :, None, None]
        if kv_mask is not None:
            K = K * kv_mask[:, :, None, None]
            values = values * kv_mask[:, :, None, None]

        v_length = values.size(1)
        values = values / v_length  # prevent fp16 overflow

        # KV = torch.einsum("nshd, nshv->nhdv", K, values)  # (S,D)' @ S,V
        k_n, k_s, k_h, k_d = K.shape
        v_n, v_s, v_h, v_v = values.shape
        K1 = K.reshape(k_n, k_s, k_h, k_d, 1)
        V1 = values.reshape(v_n, v_s, v_h, 1, v_v)
        KV = (K1*V1).sum(1) # n,s,h,d,v->n,h,d,v
        
        
        # Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
        # torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))
        Ks = K.sum(dim=1, keepdim=True) # n,1,h,d  
        QK = (Q*Ks).sum(3)
        Z = 1 / (QK + self.eps)

        # Equivalent to: queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
        qshape, kvshape, zshape = Q.shape, KV.shape, Z.shape
        Qr = Q.reshape(qshape[0], qshape[1], qshape[2], qshape[3], 1)         # [N, L, H, D, 1]
        KVr = KV.reshape(kvshape[0], 1, kvshape[1], kvshape[2], kvshape[3])   # [N, 1, H, D, V]
        Zr = Z.reshape(zshape[0], zshape[1], zshape[2], 1)                    # [N, L, H, 1]
        queried_values = ((Qr * KVr).sum(3) * Zr) * v_length                  # [N, L, H, V]
        
        return queried_values.contiguous()
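

# A minimal usage sketch for LinearAttention (not part of the original module).
# It assumes the per-head projections have already been applied, so the inputs
# are just random tensors with shapes [N, L, H, D] / [N, S, H, D]; the helper
# name `_example_linear_attention` is hypothetical and only illustrative.
def _example_linear_attention():
    N, L, S, H, D = 2, 16, 24, 8, 32
    attn = LinearAttention()
    queries = torch.randn(N, L, H, D)
    keys = torch.randn(N, S, H, D)
    values = torch.randn(N, S, H, D)
    q_mask = torch.ones(N, L, dtype=torch.bool)   # all query positions valid
    kv_mask = torch.ones(N, S, dtype=torch.bool)  # all key/value positions valid
    out = attn(queries, keys, values, q_mask=q_mask, kv_mask=kv_mask)
    assert out.shape == (N, L, H, D)
    return out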


class FullAttention(Module):
    def __init__(self, use_dropout=False, attention_dropout=0.1):
        super().__init__()
        self.use_dropout = use_dropout
        self.dropout = Dropout(attention_dropout)

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-head scaled dot-product attention, a.k.a full attention.
        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L]
            kv_mask: [N, S]
        Returns:
            queried_values: [N, L, H, D]
        """

        # Compute the unnormalized attention and apply the masks
        # QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
        bs, bc, bh, bw = queries.shape
        ks, kc, kh, kw = keys.shape
        queries1 = queries.reshape(bs, bc, 1, bh, bw)
        keys1 = keys.reshape(ks, 1, kc, kh, kw)
        QK = (queries1 * keys1).sum(3)
        
        if kv_mask is not None:
            QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))

        # Compute the attention and the weighted average
        softmax_temp = 1. / queries.size(3)**.5  # 1 / sqrt(D)
        A = torch.softmax(softmax_temp * QK, dim=2)
        if self.use_dropout:
            A = self.dropout(A)

        queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)

        return queried_values.contiguous()
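

# A minimal smoke test (not part of the original module): it runs both attention
# variants on the same random inputs with all-True masks and prints the output
# shapes, which should both be [N, L, H, D].
if __name__ == "__main__":
    N, L, S, H, D = 2, 16, 24, 8, 32
    queries = torch.randn(N, L, H, D)
    keys = torch.randn(N, S, H, D)
    values = torch.randn(N, S, H, D)
    q_mask = torch.ones(N, L, dtype=torch.bool)
    kv_mask = torch.ones(N, S, dtype=torch.bool)

    linear_out = LinearAttention()(queries, keys, values, q_mask, kv_mask)
    full_out = FullAttention()(queries, keys, values, q_mask, kv_mask)
    print("linear:", tuple(linear_out.shape), "full:", tuple(full_out.shape))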
